import os
import pickle
import numpy as np
import time
import torch
import visdom
import shutil

# Switch PyTorch's default tensor type to double precision and keep a
# module-level `Tensor` alias for it.
# NOTE(review): this runs before the project imports below, presumably so that
# `model` and `helpers` already see the double-precision default -- confirm
# before reordering.
Tensor = torch.DoubleTensor
torch.set_default_tensor_type('torch.DoubleTensor')

from model import Discriminator, Generator
from helpers import *


def printlog(line):
    """Echo ``line`` to stdout and append it to the trial's log file.

    The log file is ``save_path + 'log.txt'``; ``save_path`` is a module-level
    global that has to be assigned before the first call. ``line`` must be a
    string, since a newline is concatenated onto it before writing.
    """
    print(line)
    log_file_path = save_path + 'log.txt'
    with open(log_file_path, 'a') as log_handle:
        log_handle.write(line + '\n')


# --- Run identity and reproducibility ---
trial_id = 600          # selects the output folder saved/<trial_id>/
seed = 500              # fed to the NumPy / PyTorch seeding calls below

# --- Optimisation ---
clip = 10               # forwarded to run_epoch / pretrain_dis_iter / train_gan_iter
G_lr, D_lr = 1e-5, 1e-3  # Adam learning rates for generator / discriminator

# --- Schedule ---
pretrain_epochs = 20        # generator pre-training epochs
gan_num_iters = 80000       # iterations of the GAN training loop
pretrain_dis_iters = 2000   # upper bound on discriminator pre-training iterations
batch_size = 64
save_every = draw_every = 200  # checkpoint / sample-drawing period, in GAN iterations

# --- Hardware ---
use_gpu = False
if use_gpu:
    device = 'cuda:0'
else:
    device = 'cpu'

# Seed the PyTorch and NumPy random number generators with the configured
# `seed`; the CUDA generators are seeded too, but only when the GPU is in use.
torch.manual_seed(seed)
np.random.seed(seed)
if use_gpu:
    torch.cuda.manual_seed_all(seed)

# Constructor settings for the two networks. Each sub-dict is passed as-is to
# Generator / Discriminator further down, and the whole dict is pickled into
# the trial folder.
params = dict(
    G_params=dict(x_dim=10, z_dim=64, h_dim=200, n_layers=2),
    D_params=dict(x_dim=10, h_dim=128, n_layers=2),
)

# Create the trial folder (saved/<trial_id>/) together with its model/
# sub-folder, then store the network settings next to the checkpoints.
save_path = 'saved/%03d/' % trial_id
# Test for model/ itself rather than for save_path: a trial folder left over
# from an earlier run may exist without model/, and every torch.save call
# below writes into model/. os.makedirs also creates save_path when needed.
if not os.path.isdir(save_path + 'model/'):
    os.makedirs(save_path + 'model/')
# Use a with-block so the file handle is flushed and closed right away.
# protocol=2 is kept so the file format stays the same as before.
with open(save_path + 'params.p', 'wb') as params_file:
    pickle.dump(params, params_file, protocol=2)

# Instantiate the networks on `device` (generator first, then discriminator)
# and report how many trainable parameters each one has.
G = Generator(params['G_params'])
G = G.to(device)
print("Generator has {} trainable parameters".format(
    num_trainable_params(G)))

D = Discriminator(params['D_params'])
D = D.to(device)
print("Discriminator has {} trainable parameters".format(
    num_trainable_params(D)))

# One Adam optimizer per network, each with its own learning rate.
G_opt = torch.optim.Adam(params=G.parameters(), lr=G_lr)
D_opt = torch.optim.Adam(params=D.parameters(), lr=D_lr)

def _load_sequences(pickle_path):
    """Load one pickled dataset and return it as a tensor on ``device``.

    The unpickled object is wrapped in ``torch.Tensor``, its first two
    dimensions are swapped, and the last index along dimension 1 is dropped
    before the result is moved to the module-level ``device``.

    The file is opened inside a ``with`` block so the handle is closed as
    soon as unpickling finishes instead of being left to the garbage
    collector.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    point this at trusted data files.
    """
    with open(pickle_path, 'rb') as data_file:
        raw = pickle.load(data_file)
    return torch.Tensor(raw).transpose(0, 1)[:, :-1, :].to(device)


# Load data
train_data = _load_sequences('bball_data/data/basketball_train.p')
test_data = _load_sequences('bball_data/data/basketball_eval.p')

# Visdom client (environment 'grui') and the handles of the plot windows.
# Every window handle starts out as None; the GAN loop replaces it with the
# value returned by the vis.line call that draws into that window.
# exp_p, mod_p and win_mod_p are not referenced again in the visible part of
# this script.
vis = visdom.Visdom(env='grui')
exp_p, mod_p = [], []
win_exp_p = win_mod_p = None
win_path_length = win_out_of_bound = win_step_change = None

# Start every run with an empty 'imgs' folder: drop whatever a previous run
# left behind, then create the folder afresh. rmtree raises on failure, so
# the path is guaranteed to be gone by the time makedirs runs.
if os.path.exists('imgs'):
    shutil.rmtree('imgs')
os.makedirs('imgs')

##########################
### PRETRAIN GENERATOR ###
##########################

# Lowest test loss seen so far. The initial 0 is only a placeholder: the
# first epoch always overwrites it (see the `epoch == 1` check below).
best_test_loss = 0

# `e` is the zero-based loop index, `epoch` the one-based number used in logs.
for e, epoch in enumerate(range(1, pretrain_epochs + 1)):
    printlog("EPOCH (pretrain) [{}/{}]".format(epoch, pretrain_epochs))

    start_time = time.time()

    # run_epoch is called once on the training set and once on the test set;
    # its first argument differs between the two calls (presumably a
    # "train mode" flag -- confirm against helpers.run_epoch).
    train_loss = run_epoch(True, G, train_data, clip, G_opt, batch_size)
    printlog('Train:\t' + str(train_loss))

    test_loss = run_epoch(False, G, test_data, clip, G_opt, batch_size)
    printlog('Test:\t' + str(test_loss))

    epoch_time = time.time() - start_time
    printlog('Time:\t {:.3f}'.format(epoch_time))

    # Keep the generator weights whenever the test loss improves; the very
    # first epoch counts as an improvement unconditionally.
    improved = epoch == 1 or test_loss < best_test_loss
    if not improved:
        continue

    best_test_loss = test_loss
    filename = save_path + 'model/G_best_pretrain.pth'
    torch.save(G.state_dict(), filename)
    printlog('Best model at epoch ' + str(epoch))

printlog('End of Pretrain, Best Test Loss: {:.4f}'.format(best_test_loss))

####################
### GAN TRAINING ###
####################

# Restore the generator weights that scored best on the test set during
# pre-training. The checkpoint is read onto the CPU; load_state_dict then
# copies the values into G's existing parameters.
# (Alternative kept from the original: load save_path+'model/G.pth' instead
# to start from the latest GAN-stage checkpoint.)
best_state_dict = torch.load(
    '{}model/G_best_pretrain.pth'.format(save_path),
    map_location='cpu',
)
G.load_state_dict(best_state_dict)

# Pre-train the discriminator for at most `pretrain_dis_iters` iterations,
# stopping early as soon as fake_probs_val drops below 0.3.
# NOTE(review): fake_probs_val is presumably D's mean probability on
# generated samples -- confirm against helpers.pretrain_dis_iter.
for e in range(pretrain_dis_iters):
    real_probs_val, fake_probs_val = pretrain_dis_iter(
        G, D, D_opt, train_data, clip, batch_size)
    if fake_probs_val < 0.3:
        break

# Snapshot both networks after discriminator pre-training, but only when the
# configured budget exceeds 200 iterations.
if pretrain_dis_iters > 200:
    for filename, net in (
            (save_path + 'model/G_pretrain.pth', G),
            (save_path + 'model/D_pretrain.pth', D)):
        torch.save(net.state_dict(), filename)

# Optional resume: set load_prev_dis to True to start the GAN stage from the
# discriminator checkpoint model/D.pth (the file the training loop below
# rewrites every `save_every` iterations).
load_prev_dis = False
if load_prev_dis:
    # map_location='cpu' mirrors how the generator checkpoint is loaded above
    # and lets a checkpoint written on a GPU machine be read on a CPU-only
    # one; load_state_dict then copies the values into D's own parameters.
    dis_dict = torch.load(save_path+'model/D.pth', map_location='cpu')
    D.load_state_dict(dis_dict)

for e in range(gan_num_iters):
    printlog("EPOCH (GAN) [{}/{}]".format(e, gan_num_iters))

    start_time = time.time()

    # One GAN training iteration; the two returned values are plotted below
    # under the legends 'expert_prob' and 'model_prob'.
    real_probs_val, fake_probs_val = train_gan_iter(
        G, D, G_opt, D_opt, train_data, clip, batch_size)

    # Iteration 0 creates each visdom window; later iterations append to it.
    update = None if e == 0 else 'append'
    win_exp_p = vis.line(
        X=np.array([e]),
        Y=np.column_stack((real_probs_val, fake_probs_val)),
        win=win_exp_p,
        update=update,
        opts={'legend': ['expert_prob', 'model_prob'],
              'title': "training curve probs"},
    )

    # Every `draw_every` iterations: sample from the generator, draw the
    # samples and plot their summary statistics.
    if e % draw_every == 0:
        samples = G.generate(10, T=49, device=device)
        mod_stats = draw_and_stats(
            samples.data, 'GAN_stage', e, draw=True, compute_stats=True)

        # x-coordinate of the statistics plots: index of this drawing round.
        draw_idx = e // draw_every
        win_path_length = vis.line(
            X=np.array([draw_idx]),
            Y=np.array([mod_stats['ave_length']]),
            win=win_path_length,
            update=update,
            opts={'legend': ['model'], 'title': "average path length"},
        )
        win_out_of_bound = vis.line(
            X=np.array([draw_idx]),
            Y=np.array([mod_stats['ave_out_of_bound']]),
            win=win_out_of_bound,
            update=update,
            opts={'legend': ['model'], 'title': "average out of bound rate"},
        )
        win_step_change = vis.line(
            X=np.array([draw_idx]),
            Y=np.array([mod_stats['ave_change_step_size']]),
            win=win_step_change,
            update=update,
            opts={'legend': ['model'], 'title': "average step size change"},
        )

    epoch_time = time.time() - start_time
    printlog('Time:\t {:.3f}'.format(epoch_time))

    # Rolling checkpoint: G.pth / D.pth are overwritten every `save_every`
    # iterations (including iteration 0).
    if e % save_every == 0:
        for filename, net in (
                (save_path + 'model/G.pth', G),
                (save_path + 'model/D.pth', D)):
            torch.save(net.state_dict(), filename)

# Write the final weights of both networks once the GAN loop has finished.
for filename, net in (
        (save_path + 'model/G_final.pth', G),
        (save_path + 'model/D_final.pth', D)):
    torch.save(net.state_dict(), filename)

